{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "import torch.nn as nn\n",
    "from scipy import sparse\n",
    "from torch.utils.data import DataLoader,Dataset\n",
    "import torch.nn.functional as F\n",
    "import torch.nn.init as init\n",
    "from scipy.linalg import expm\n",
    "import math\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data():\n",
    "    dise_sim = np.load('data\\disease_disease.npy')#138*138\n",
    "    c_d = np.load('data\\circRNA_disease.npy')#834*138\n",
    "    c_m = np.load('data\\circRNA_miRNA.npy')#834*555\n",
    "    m_d = np.load('data\\disease_miRNA.npy').T #555*138\n",
    "    return torch.tensor(dise_sim,dtype=float),torch.tensor(c_d,dtype=torch.long),torch.tensor(c_m,dtype=torch.long),torch.tensor(m_d,dtype=torch.long)\n",
    "\n",
    "import itertools\n",
    "def calculate_sim(cd,dd):\n",
    "    s1=cd.shape[0]\n",
    "    ll=torch.eye(s1)\n",
    "    m2=dd*cd[:,None,:]\n",
    "    m1=cd[:,:,None]\n",
    "    for x,y in itertools.permutations(torch.linspace(0,s1-1,s1,dtype=torch.long),2):\n",
    "        x,y=x.item(),y.item()\n",
    "        m=m1[x,:,:]*m2[y,:,:]\n",
    "        if cd[x].sum()+cd[y].sum()==0:\n",
    "            ll[x,y]=0\n",
    "        else:\n",
    "            ll[x,y]=(m.max(dim=0,keepdim=True)[0].sum()+m.max(dim=1,keepdim=True)[0].sum())/(cd[x].sum()+cd[y].sum())\n",
    "    return ll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_dataset(c_d,dise_sim):\n",
    "    cd_axix = torch.where(c_d==1) #989   表示989对关联\n",
    "    circ_x = cd_axix[0] #989\n",
    "    dise_y = cd_axix[1] #989\n",
    "    #有顺序的坐标\n",
    "    positive_sample = torch.cat([circ_x[:,None],dise_y[:,None]],dim=1) #989*2\n",
    "    rand_index1 = torch.randperm(c_d.sum()) #0~988\n",
    "    #打乱顺序\n",
    "    ps = torch.index_select(positive_sample,dim=0,index=rand_index1) #989*2，在上面的positive_sample中按列随机选中\n",
    "\n",
    "    #负样本\n",
    "    cd_axix_neg = torch.where(c_d==0) \n",
    "    circ_x_neg = cd_axix_neg[0]  #114103\n",
    "    dise_y_neg = cd_axix_neg[1]  #114103\n",
    "    #有顺序的坐标\n",
    "    negative_sample = torch.cat([circ_x_neg[:,None],dise_y_neg[:,None]],dim=1) #114103*2\n",
    "    rand_index2 = torch.randperm(negative_sample.shape[0])\n",
    "    ns = torch.index_select(input=negative_sample,dim=0,index=rand_index2) #114103*2    \n",
    "\n",
    "    #取与正样本相同量的负样本\n",
    "    pos_len = ps.shape[0]\n",
    "    train_neg = ns[0:pos_len,:] #989*2\n",
    "    train_dataset = torch.cat([ps.T,train_neg.T],dim=1) #2*(989*2）\n",
    "\n",
    "    test_dataset = torch.cat([ps.T,ns.T],dim=1)#2*(114103+989)\n",
    "\n",
    "    #计算circ_sim\n",
    "    c_sim = calculate_sim(c_d,dise_sim)\n",
    "    return train_dataset,test_dataset,c_sim\n",
    "\n",
    "dise_sim,c_d,c_m,m_d = load_data()\n",
    "train_dataset,test_dataset,c_sim = split_dataset(c_d,dise_sim)\n",
    "mi_sim = calculate_sim(m_d,dise_sim)\n",
    "\n",
    "#构造无遮盖邻接矩阵\n",
    "circ = torch.cat([c_sim,c_d,c_m],dim=1)\n",
    "dise = torch.cat([c_d.T,dise_sim,m_d.T],dim=1)\n",
    "micr = torch.cat([c_m.T,m_d,mi_sim],dim=1)\n",
    "uncover_adj = torch.cat([circ,dise,micr],dim=0)\n",
    "print(uncover_adj.shape)\n",
    "torch.save(train_dataset,rf'data\\caseStudy\\train_dataset')\n",
    "torch.save(test_dataset,rf'data\\caseStudy\\test_dataset')\n",
    "torch.save(uncover_adj,rf'data\\caseStudy\\uncover_adj')\n",
    "torch.save(c_d,rf'data\\caseStudy\\c_d')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adj_matrix = uncover_adj\n",
    "circ_sim = adj_matrix[0:834,0:834]\n",
    "c_d = adj_matrix[0:834,834:834+138]\n",
    "\n",
    "di_sim = adj_matrix[834:834+138,834:834+138]\n",
    "    \n",
    "c_c_c = torch.matmul(circ_sim,circ_sim)\n",
    "c_d_c = torch.matmul(c_d,c_d.T)\n",
    "    \n",
    "d_c_d = torch.matmul(c_d.T,c_d)\n",
    "d_d_d = torch.matmul(di_sim,di_sim)\n",
    "torch.save(c_d,rf'data\\caseStudy\\cd.pth')\n",
    "torch.save(circ_sim,rf'data\\caseStudy\\cc.pth')\n",
    "torch.save(di_sim,rf'data\\caseStudy\\dd.pth')\n",
    "torch.save(c_c_c,rf'data\\caseStudy\\ccc.pth')\n",
    "torch.save(c_d_c,rf'data\\caseStudy\\cdc.pth')\n",
    "torch.save(d_c_d,rf'data\\caseStudy\\dcd.pth')\n",
    "torch.save(d_d_d,rf'data\\caseStudy\\ddd.pth')"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
